{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extracting cropped images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from scipy import misc\n",
    "from random import shuffle, random\n",
    "import cv2\n",
    "import os\n",
    "import six.moves.urllib as urllib\n",
    "import tarfile\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from time import gmtime, strftime\n",
    "import yaml\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initializing Tensorflow detection API\n",
    "from object_detection.utils import label_map_util\n",
    "from object_detection.utils import visualization_utils as vis_util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DetectionObj(object):\n",
    "    \"\"\"\n",
    "    DetectionObj is a class suitable to leverage Google Tensorflow\n",
    "    detection API for image annotation from different sources:\n",
    "    files, images acquired by own's webcam, videos.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model='ssd_mobilenet_v1_coco_11_06_2017'):\n",
    "        \"\"\"\n",
    "        The instructions to be run when the class is instantiated\n",
    "        \"\"\"\n",
    "\n",
    "        # Path where the Python script is being run\n",
    "        self.CURRENT_PATH = os.getcwd()\n",
    "\n",
    "        # Path where to save the annotations (it can be modified)\n",
    "        self.TARGET_PATH = self.CURRENT_PATH\n",
    "\n",
    "        # Selection of pre-trained detection models\n",
    "        # from the Tensorflow Model Zoo\n",
    "        self.MODELS = [\"ssd_mobilenet_v1_coco_11_06_2017\",\n",
    "                       \"ssd_inception_v2_coco_11_06_2017\",\n",
    "                       \"rfcn_resnet101_coco_11_06_2017\",\n",
    "                       \"faster_rcnn_resnet101_coco_11_06_2017\",\n",
    "                       \"faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017\"\n",
    "                       ]\n",
    "\n",
    "        # Setting a threshold for detecting an object by the models\n",
    "        self.THRESHOLD = 0.25 # Most used threshold in practice\n",
    "\n",
    "        # Checking if the desired pre-trained detection model is available\n",
    "        if model in self.MODELS:\n",
    "            self.MODEL_NAME = model\n",
    "        else:\n",
    "            # Otherwise revert to a default model\n",
    "            print(\"Model not available, reverted to default\", self.MODELS[0])\n",
    "            self.MODEL_NAME = self.MODELS[0]\n",
    "\n",
    "        # The file name of the Tensorflow frozen model\n",
    "        self.CKPT_FILE = os.path.join(self.CURRENT_PATH, 'object_detection',\n",
    "                                      self.MODEL_NAME, 'frozen_inference_graph.pb')\n",
    "\n",
    "        # Attempting loading the detection model, if not available on disk,\n",
    "        # it will be downloaded from Internet(an Internet connection is required)\n",
    "        try:\n",
    "            self.DETECTION_GRAPH = self.load_frozen_model()\n",
    "        except:\n",
    "            print ('Couldn\\'t find', self.MODEL_NAME)\n",
    "            self.download_frozen_model()\n",
    "            self.DETECTION_GRAPH = self.load_frozen_model()\n",
    "\n",
    "        # Loading the labels of the classes recognized by the detection model\n",
    "        self.NUM_CLASSES = 90\n",
    "        path_to_labels = os.path.join(self.CURRENT_PATH,\n",
    "                                      'object_detection', 'data', 'mscoco_label_map.pbtxt')\n",
    "        label_mapping = label_map_util.load_labelmap(path_to_labels)\n",
    "        extracted_categories = label_map_util.convert_label_map_to_categories(label_mapping,\n",
    "                                                                    max_num_classes=self.NUM_CLASSES,\n",
    "                                                                    use_display_name=True)\n",
    "        self.LABELS = {item['id']: item['name'] for item in extracted_categories}\n",
    "        self.CATEGORY_INDEX = label_map_util.create_category_index(extracted_categories)\n",
    "\n",
    "        # Starting the tensorflow session\n",
    "        self.TF_SESSION = tf.Session(graph=self.DETECTION_GRAPH)\n",
    "\n",
    "    def load_frozen_model(self):\n",
    "        \"\"\"\n",
    "        Loading frozen detection model in ckpt file from disk to memory \n",
    "        \"\"\"\n",
    "        detection_graph = tf.Graph()\n",
    "        with detection_graph.as_default():\n",
    "            od_graph_def = tf.GraphDef()\n",
    "            with tf.gfile.GFile(self.CKPT_FILE, 'rb') as fid:\n",
    "                serialized_graph = fid.read()\n",
    "                od_graph_def.ParseFromString(serialized_graph)\n",
    "                tf.import_graph_def(od_graph_def, name='')\n",
    "        return detection_graph\n",
    "\n",
    "    def download_frozen_model(self):\n",
    "        \"\"\"\n",
    "        Downloading frozen detection model from Internet \n",
    "        when not available on disk \n",
    "        \"\"\"\n",
    "        def my_hook(t):\n",
    "            \"\"\"\n",
    "            Wrapping tqdm instance in order to monitor URLopener  \n",
    "            \"\"\"\n",
    "            last_b = [0]\n",
    "\n",
    "            def inner(b=1, bsize=1, tsize=None):\n",
    "                if tsize is not None:\n",
    "                    t.total = tsize\n",
    "                t.update((b - last_b[0]) * bsize)\n",
    "                last_b[0] = b\n",
    "\n",
    "            return inner\n",
    "\n",
    "        # Opening the url where to find the model\n",
    "        model_filename = self.MODEL_NAME + '.tar.gz'\n",
    "        download_url = 'http://download.tensorflow.org/models/object_detection/'\n",
    "        opener = urllib.request.URLopener()\n",
    "\n",
    "        # Downloading the model with tqdm estimations of completion\n",
    "        print('Downloading ...')\n",
    "        with tqdm() as t:\n",
    "            opener.retrieve(download_url + model_filename,\n",
    "                            model_filename, reporthook=my_hook(t))\n",
    "\n",
    "        # Extracting the model from the downloaded tar file\n",
    "        print ('Extracting ...')\n",
    "        tar_file = tarfile.open(model_filename)\n",
    "        for file in tar_file.getmembers():\n",
    "            file_name = os.path.basename(file.name)\n",
    "            if 'frozen_inference_graph.pb' in file_name:\n",
    "                tar_file.extract(file, os.path.join(self.CURRENT_PATH,\n",
    "                                                    'object_detection'))\n",
    "\n",
    "    def load_image_from_disk(self, image_path):\n",
    "        \"\"\"\n",
    "        Loading an image from disk\n",
    "        \"\"\"\n",
    "        return Image.open(image_path)\n",
    "\n",
    "    def load_image_into_numpy_array(self, image):\n",
    "        \"\"\"\n",
    "        Turning an image into a Numpy ndarray\n",
    "        \"\"\"\n",
    "        try:\n",
    "            (im_width, im_height) = image.size\n",
    "            return np.array(image.getdata()).reshape(\n",
    "                (im_height, im_width, 3)).astype(np.uint8)\n",
    "        except:\n",
    "            # If the previous procedure fails, we expect the\n",
    "            # image is already a Numpy ndarray\n",
    "            return image\n",
    "\n",
    "    def detect(self, images, annotate_on_image=True):\n",
    "        \"\"\"\n",
    "        Processing a list of images, feeding it into the detection\n",
    "        model and getting from it scores, bounding boxes and predicted\n",
    "        classes present in the images\n",
    "        \"\"\"\n",
    "        if type(images) is not list:\n",
    "            images = [images]\n",
    "        results = list()\n",
    "        for image in images:\n",
    "            # the array based representation of the image will be used later in order to prepare the\n",
    "            # result image with boxes and labels on it.\n",
    "            image_np = self.load_image_into_numpy_array(image)\n",
    "            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]\n",
    "            image_np_expanded = np.expand_dims(image_np, axis=0)\n",
    "            image_tensor = self.DETECTION_GRAPH.get_tensor_by_name('image_tensor:0')\n",
    "            # Each box represents a part of the image where a particular object was detected.\n",
    "            boxes = self.DETECTION_GRAPH.get_tensor_by_name('detection_boxes:0')\n",
    "            # Each score represent how level of confidence for each of the objects.\n",
    "            # Score could be shown on the result image, together with the class label.\n",
    "            scores = self.DETECTION_GRAPH.get_tensor_by_name('detection_scores:0')\n",
    "            classes = self.DETECTION_GRAPH.get_tensor_by_name('detection_classes:0')\n",
    "            num_detections = self.DETECTION_GRAPH.get_tensor_by_name('num_detections:0')\n",
    "            # Actual detection happens here\n",
    "            (boxes, scores, classes, num_detections) = self.TF_SESSION.run(\n",
    "                [boxes, scores, classes, num_detections],\n",
    "                feed_dict={image_tensor: image_np_expanded})\n",
    "            if annotate_on_image:\n",
    "                new_image = self.detection_on_image(image_np, boxes, scores, classes)\n",
    "                results.append((new_image, boxes, scores, classes, num_detections))\n",
    "            else:\n",
    "                results.append((image_np, boxes, scores, classes, num_detections))\n",
    "        return results\n",
    "    \n",
    "    def detection_on_image(self, image_np, boxes, scores, classes):\n",
    "        \"\"\"\n",
    "        Overimposing detection boxes on the images over the detected classes: \n",
    "        \"\"\"\n",
    "        vis_util.visualize_boxes_and_labels_on_image_array(\n",
    "            image_np,\n",
    "            np.squeeze(boxes),\n",
    "            np.squeeze(classes).astype(np.int32),\n",
    "            np.squeeze(scores),\n",
    "            self.CATEGORY_INDEX,\n",
    "            use_normalized_coordinates=True,\n",
    "            line_thickness=8)\n",
    "        return image_np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def visualize(data, saving=False):\n",
    "    \"\"\"visualizing and saving an image\"\"\"\n",
    "    img = Image.fromarray(data, 'RGB')\n",
    "    if saving:\n",
    "        img.save('my.png')\n",
    "    img.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def read_image_from_disk(filename):\n",
    "    \"\"\"\n",
    "    Reads a png image from disk and \n",
    "    converts it into a Numpy ndarray\n",
    "    \"\"\"\n",
    "    file_contents = misc.imread(filename)\n",
    "    return file_contents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_simulation_training():\n",
    "    \"\"\"Arranging examples from simulator's images\"\"\"\n",
    "    labels, filenames = (list(), list())\n",
    "    for k, light in enumerate(['red', 'yellow', 'green', 'none']):\n",
    "        path = os.path.join(os.getcwd(), 'simulator_images', light)\n",
    "        examples = [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]\n",
    "        labels += [light] * len(examples)\n",
    "        filenames += examples\n",
    "    return np.array(filenames), np.array(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_rosbag_training():\n",
    "     \"\"\"Arranging examples from rosbag's images\"\"\"\n",
    "    labels, filenames = (list(), list())\n",
    "    for k, light in enumerate(['red','green']):\n",
    "        path = os.path.join(os.getcwd(), 'ros_bag_images', light)\n",
    "        examples = [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]\n",
    "        labels += [light] * len(examples)\n",
    "        filenames += examples\n",
    "    return np.array(filenames), np.array(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_all_bosch_labels(input_yaml, riib=False):\n",
    "    \"\"\" Gets all labels within label file\n",
    "    Note that RGB images are 1280x720 and RIIB images are 1280x736.\n",
    "    :param input_yaml: Path to yaml file\n",
    "    :param riib: If True, change path to labeled pictures\n",
    "    :return: images: Labels for traffic lights\n",
    "    \"\"\"\n",
    "    images = yaml.load(open(input_yaml, 'rb').read())\n",
    "\n",
    "    for i in range(len(images)):\n",
    "        images[i]['path'] = os.path.abspath(os.path.join(os.path.dirname(input_yaml), images[i]['path']))\n",
    "        if riib:\n",
    "            images[i]['path'] = images[i]['path'].replace('.png', '.pgm')\n",
    "            images[i]['path'] = images[i]['path'].replace('rgb/train', 'riib/train')\n",
    "            images[i]['path'] = images[i]['path'].replace('rgb/test', 'riib/test')\n",
    "            for box in images[i]['boxes']:\n",
    "                box['y_max'] = box['y_max'] + 8\n",
    "                box['y_min'] = box['y_min'] + 8\n",
    "    return images\n",
    "\n",
    "def process_bosch_training_data(filename):\n",
    "     \"\"\"Arranging examples from Bosch's images\"\"\"\n",
    "    train_labels = get_all_bosch_labels(filename)\n",
    "    numeric_label = {'R':'red', 'Y':'yellow', 'G':'green', 'V':'null'}\n",
    "    preprocessed_labels = list()\n",
    "    labels = list()\n",
    "    filenames = list()\n",
    "    for n, label in enumerate(train_labels):\n",
    "        preprocessed_labels.append(list())\n",
    "        for box in label['boxes']:\n",
    "            preprocessed_labels[n].append(box['label'])\n",
    "        if len(preprocessed_labels[n]) == 0:\n",
    "            preprocessed_labels[n].append('Void')\n",
    "        color = {item[0] for item in preprocessed_labels[n] if item!='off'}\n",
    "        if len(color) == 1:\n",
    "            filenames.append(label['path'])\n",
    "            labels.append(numeric_label[list(color)[0]])\n",
    "    filenames = np.array(filenames)\n",
    "    labels = np.array(labels)\n",
    "    return filenames, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def extract_bounding_boxes_bosch(target_label='Yellow'):\n",
    "     \"\"\"Arranging examples from simulator's images based on provided bounding boxes\"\"\"\n",
    "    sequence = get_all_bosch_labels('./bosch/train.yaml')\n",
    "    CURRENT_PATH = detector.CURRENT_PATH\n",
    "    source = 'bosch'\n",
    "    for k, item in enumerate(sequence):\n",
    "        image = list()\n",
    "        for box in item['boxes']:\n",
    "            if box['label'] == target_label:\n",
    "                x1, y1, x2, y2 = (box['x_max'], box['y_max'], box['x_min'], box['y_min'])\n",
    "                width  = int((x1 - x2) * 0.2)\n",
    "                height = int((y1 - y2) * 0.2)\n",
    "                if len(image)==0:\n",
    "                    image = read_image_from_disk(item['path'])\n",
    "                cropped_image = image[int(y2)-height:int(y1)+height, int(x2)-width:int(x1)+width, :]\n",
    "                if cropped_image.shape[0] * cropped_image.shape[1] >= 10 * 10:\n",
    "                    resized = misc.imresize(cropped_image, (32, 64))\n",
    "                    saving_path = os.path.join(CURRENT_PATH, 'small_lights', target_label.lower()+'_'+source+'_'+str(int(random()*10**6))+\".jpg\")\n",
    "                    cv2.imwrite(saving_path, cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def crop_traffic_lights(images, labels, source=''):\n",
    "    \"\"\"detecting traffic lights in images and cropping them. Providing an equivalent number of false exampels\"\"\"\n",
    "    results = detector.detect(images, annotate_on_image=False)\n",
    "    for result, label in zip(results, labels):\n",
    "        image_np, boxes, scores, classes, num_detections = result\n",
    "        x, y, _ = image_np.shape\n",
    "\n",
    "        cropped_items = list()\n",
    "        cropped_areas = list()\n",
    "        for k, item in enumerate(classes[0]):\n",
    "            if item== 10:\n",
    "                box, score = boxes[0][k], scores[0][k]\n",
    "                if score > 0.70:\n",
    "                    cropped_items.append(image_np[int(box[0]*x):int(box[2]*x), int(box[1]*y):int(box[3]*y), :])\n",
    "                    cropped_areas.append(box)\n",
    "\n",
    "        empty_examples = []\n",
    "        if cropped_items:\n",
    "            for item, box in zip(cropped_items, cropped_areas):\n",
    "                px, py, _ = item.shape\n",
    "                while 1==1:\n",
    "                    random_point = (int(x*random()), int(y*random()))\n",
    "                    valid = ((random_point[0] < min(box[0], box[2])) | (random_point[0] > max(box[0], box[2]))) | ((random_point[1] < min(box[1], box[3])) | (random_point[1] > max(box[1], box[3])))\n",
    "                    if valid:\n",
    "                        break\n",
    "                empty_examples.append(image_np[random_point[0]:(random_point[0]+px), random_point[1]:(random_point[1]+py), :])\n",
    "        else:\n",
    "            for j in range(3):\n",
    "                px, py = (32*3, 64*3)\n",
    "                random_point = (int(x*random()), int(y*random()))\n",
    "                empty_examples.append(image_np[random_point[0]:(random_point[0]+px), random_point[1]:(random_point[1]+py), :])\n",
    "\n",
    "        CURRENT_PATH = detector.CURRENT_PATH\n",
    "        for item in cropped_items:\n",
    "            resized = misc.imresize(item, (32, 64))\n",
    "            saving_path = os.path.join(CURRENT_PATH, 'small_lights', label+'_'+source+'_'+str(int(random()*10**6))+\".jpg\")\n",
    "            cv2.imwrite(saving_path, cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))\n",
    "        for item in empty_examples:\n",
    "            resized = misc.imresize(item, (32, 64))\n",
    "            saving_path = os.path.join(CURRENT_PATH, 'small_lights', 'none_'+source+'_'+str(int(random()*10**6))+\".jpg\")\n",
    "            cv2.imwrite(saving_path, cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def chunks(l, n):\n",
    "    \"\"\"Yield successive n-sized chunks from l.\"\"\"\n",
    "    for i in range(0, len(l), n):\n",
    "        yield l[i:i + n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# initializing detector\n",
    "detector = DetectionObj('faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# processing simulator images\n",
    "sim_filenames, sim_labels = get_simulation_training()\n",
    "\n",
    "images = [read_image_from_disk(item) for item in sim_filenames]\n",
    "labels = sim_labels\n",
    "crop_traffic_lights(images, labels, source='sim')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# processing rosbag images\n",
    "rosbag_filenames, rosbag_labels = get_rosbag_training()\n",
    "\n",
    "for rb_files, rb_labels in zip(chunks(rosbag_filenames, 5), chunks(rosbag_labels, 5)):\n",
    "    images = [read_image_from_disk(item) for item in rb_files]\n",
    "    labels = rb_labels\n",
    "    crop_traffic_lights(images, labels, source='rosbag')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# processing Bosch dataset\n",
    "bosch_filenames, bosch_labels = process_bosch_training_data('./bosch/train.yaml')\n",
    "\n",
    "for rb_files, rb_labels in zip(chunks(bosch_filenames, 5), chunks(bosch_labels, 5)):\n",
    "    if 'yellow' in rb_labels:\n",
    "        images = [read_image_from_disk(item) for item in rb_files]\n",
    "        labels = rb_labels\n",
    "        crop_traffic_lights(images, labels, source='bosch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# extracting data from Bosch dataset based on the provided bounding boxes\n",
    "extract_bounding_boxes_bosch(target_label='Yellow')\n",
    "extract_bounding_boxes_bosch(target_label='Red')\n",
    "extract_bounding_boxes_bosch(target_label='Green')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
